package com.example.encrypt_demo.filter;

import com.example.encrypt_demo.util.AesUtil;
import com.example.encrypt_demo.util.RsaUtil;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Map;
import java.util.Vector;

/**
 * Request wrapper that decrypts request parameter values on access.
 *
 * <p>Parameter values are decrypted with {@code AesUtil.aesDecrypt}. The AES key is taken from
 * the {@value #KEY_NAME} request header and is itself decrypted with
 * {@code RsaUtil.rsaDecrypt} using {@code RsaUtil.PRIVATE_KEY}. The paging parameters
 * {@code rows} and {@code page} are returned exactly as received.
 *
 * <p>Only {@link #getParameterNames()}, {@link #getParameter(String)} and
 * {@link #getParameterValues(String)} are overridden. {@code getParameterMap()} is not
 * overridden, so it bypasses the decryption performed by this class.
 *
 * @author benjamin_5
 * @date 2024/9/9
 */
public class DecryptParamsWrapper extends HttpServletRequestWrapper {

    private static final Logger LOGGER = LogManager.getLogger(DecryptParamsWrapper.class);

    /** Name of the request header that carries the RSA-encrypted AES key. */
    public static final String KEY_NAME = "encryptKey";
    public static final String SIGN_NAME = "sign";

    /**
     * Parameter map obtained from the wrapped request; its values are never modified here.
     */
    private final Map<String, String[]> parameterMap;

    /** The wrapped request; kept package-visible for backward compatibility. */
    HttpServletRequest request;

    /**
     * AES key resolved from the request header; {@code null} until first successfully decrypted.
     */
    private String aesKey;

    /**
     * Wraps the given request and captures its parameter map.
     *
     * @param request the request whose parameters should be decrypted on access
     */
    public DecryptParamsWrapper(HttpServletRequest request) {
        super(request);
        this.request = request;
        this.parameterMap = request.getParameterMap();
    }

    /**
     * Returns the names of all parameters of the wrapped request.
     *
     * @return an enumeration over all parameter names
     */
    @Override
    public Enumeration<String> getParameterNames() {
        return Collections.enumeration(parameterMap.keySet());
    }

    /**
     * Returns the decrypted value of the given parameter. If the parameter occurs more than
     * once, the first value is used.
     *
     * @param name the parameter name
     * @return the decrypted first value, the raw first value for {@code rows}/{@code page},
     *         or {@code null} if the parameter is absent or has no values
     * @throws RuntimeException if the key or the value cannot be decrypted
     */
    @Override
    public String getParameter(String name) {
        String[] results = parameterMap.get(name);
        if (results == null || results.length == 0) {
            return null;
        }
        return isPlainParameter(name) ? results[0] : decrypt(results[0]);
    }

    /**
     * Returns all values of the given parameter, decrypted.
     *
     * @param name the parameter name
     * @return a new array holding the decrypted values, the raw values for
     *         {@code rows}/{@code page}, or {@code null} if the parameter is absent or has no
     *         values
     * @throws RuntimeException if the key or any value cannot be decrypted
     */
    @Override
    public String[] getParameterValues(String name) {
        String[] results = parameterMap.get(name);
        if (results == null || results.length == 0) {
            return null;
        }
        if (isPlainParameter(name)) {
            return results;
        }
        // Decrypt into a fresh array: writing back into the array held by the parameter map
        // would make any later read of the same parameter decrypt already-decrypted text.
        String[] decrypted = new String[results.length];
        for (int i = 0; i < results.length; i++) {
            decrypted[i] = decrypt(results[i]);
        }
        return decrypted;
    }

    /**
     * Tells whether the parameter is one of the paging parameters that are sent unencrypted.
     *
     * @param name the parameter name
     * @return {@code true} for {@code rows} and {@code page}
     */
    private static boolean isPlainParameter(String name) {
        return "rows".equals(name) || "page".equals(name);
    }

    /**
     * AES-decrypts a single parameter value with the key resolved by {@link #decryptKey()}.
     *
     * @param data the encrypted parameter value
     * @return the decrypted value
     * @throws RuntimeException wrapping the underlying failure if decryption fails
     */
    private String decrypt(String data) {
        try {
            return AesUtil.aesDecrypt(data, decryptKey());
        } catch (Exception e) {
            LOGGER.error("参数解密失败: ", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Resolves the AES key by RSA-decrypting the {@value #KEY_NAME} request header.
     *
     * <p>The key is resolved on first use and then reused for this wrapper, so the RSA
     * operation is not repeated for every parameter value.
     *
     * @return the decrypted AES key
     * @throws RuntimeException if the header value cannot be decrypted; the original failure
     *         is attached as the cause
     */
    private String decryptKey() {
        if (aesKey == null) {
            String keyEncrypt = this.request.getHeader(KEY_NAME);
            try {
                aesKey = RsaUtil.rsaDecrypt(keyEncrypt, RsaUtil.PRIVATE_KEY);
            } catch (Exception e) {
                // Keep the cause so the caller's log entry shows why the key was rejected.
                throw new RuntimeException("密钥解密失败", e);
            }
        }
        return aesKey;
    }
}
